{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZZxhwHjOvnWF"
      },
      "source": [
        "Loading factorizations found by AlphaTensor and recombination.\n",
        "\n",
        "- Copyright 2022 DeepMind Technologies Limited\n",
        "- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0\n",
        "- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode\n",
        "- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.\n",
        "- This is not an official Google product."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lv32k_zmYXta"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from google.colab import files"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fpWZeG2V3ZV0"
      },
      "source": [
        "Upload one of the two files provided in the same folder: `factorization_r.npz` (algorithms in standard arithmetic) or `factorization_f2.npz` (algorithms in arithmetic modulo 2)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FrXikWaPYO1n"
      },
      "outputs": [],
      "source": [
        "uploaded = files.upload()\n",
        "filename = list(uploaded.keys())[0]\n",
        "with open(filename, 'rb') as f:\n",
        "  factorizations = dict(np.load(f, allow_pickle=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AKiXSgk0YVRn"
      },
      "outputs": [],
      "source": [
        "# Print the available factorizations with their ranks and factor shapes.\n",
        "for key in factorizations:\n",
        "  u, v, w = factorizations[key]\n",
        "  rank = u.shape[-1]\n",
        "  assert rank == v.shape[-1] and rank == w.shape[-1]\n",
        "  print(f'{key}: rank={rank}, shapes={u.shape}, {v.shape}, {w.shape}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PDbovsIXC8-H"
      },
      "source": [
        "Note that, as provided, the factorizations decompose the *symmetrized* version of the matrix multiplication tensor, which represents the bilinear operation $\\mathbf{A}, \\mathbf{B} \\mapsto (\\mathbf{A} \\cdot \\mathbf{B})^T$. This convention is standard in the literature, and a factorization can be converted between the symmetrized and non-symmetrized versions by reordering the rows of the third factor."
      ]
    },
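    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of that conversion (an illustration, not part of the original release): the third mode of the symmetrized tensor indexes the entries of $(\\mathbf{A} \\cdot \\mathbf{B})^T$ as `k * a + i`, so moving each row of `w` to position `i * c + k` should give the third factor for the non-symmetrized tensor, while `u` and `v` stay unchanged. The helper name `desymmetrize_w` and the one-hot toy input are assumptions made for this example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def desymmetrize_w(w: np.ndarray, a: int, c: int) -> np.ndarray:\n",
        "  \"\"\"Reorders the rows of w from index k * a + i to index i * c + k.\"\"\"\n",
        "  rank = w.shape[-1]\n",
        "  # View the rows as a (c, a) grid, swap the two axes, and flatten again.\n",
        "  return w.reshape(c, a, rank).transpose(1, 0, 2).reshape(a * c, rank)\n",
        "\n",
        "\n",
        "# Toy check with a hypothetical one-hot factor (a = 2, c = 3).\n",
        "toy_a, toy_c = 2, 3\n",
        "w_sym = np.eye(toy_a * toy_c, dtype=np.int32)\n",
        "w_plain = desymmetrize_w(w_sym, toy_a, toy_c)\n",
        "for i in range(toy_a):\n",
        "  for k in range(toy_c):\n",
        "    assert np.array_equal(w_plain[i * toy_c + k], w_sym[k * toy_a + i])"
      ]
    },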
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WZK9ReTpbGUu"
      },
      "outputs": [],
      "source": [
        "def get_mamu_tensor_rectangular(a: int, b: int, c: int) -\u003e np.ndarray:\n",
        "  \"\"\"Returns the symmetrized matrix multiplication tensor T_{a, b, c}.\"\"\"\n",
        "  result = np.zeros((a * b, b * c, c * a), dtype=np.int32)\n",
        "  for i in range(a):\n",
        "    for j in range(b):\n",
        "      for k in range(c):\n",
        "        result[i * b + j, j * c + k, k * a + i] = 1\n",
        "  return result\n",
        "\n",
        "\n",
        "# Test correctness of a factorization.\n",
        "tensor = get_mamu_tensor_rectangular(3, 4, 5)\n",
        "u, v, w = factorizations['3,4,5']\n",
        "reconstruction = np.einsum('ir,jr,kr-\u003eijk', u, v, w)\n",
        "if np.array_equal(tensor, reconstruction):\n",
        "  print('Factorization is correct in R (standard arithmetic).')\n",
        "elif np.array_equal(tensor, np.mod(reconstruction, 2)):\n",
        "  print('Factorization is correct in F2 (arithmetic modulo 2).')\n",
        "else:\n",
        "  print('Factorization is incorrect.')"
      ]
    },
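    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A factorization that passes this check can be read as a matrix multiplication algorithm that uses `rank` scalar multiplications. The cell below is a sketch under assumptions, not code from the original release: `standard_factorization` and `multiply_with_factorization` are hypothetical helper names, and the rank-`a * b * c` factorization of the schoolbook algorithm stands in for an uploaded one so that the cell runs on its own. For a factorization that is only valid in F2, the result would additionally have to be reduced modulo 2."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def standard_factorization(a: int, b: int, c: int):\n",
        "  \"\"\"Returns the rank a*b*c factorization given by the schoolbook algorithm.\"\"\"\n",
        "  rank = a * b * c\n",
        "  u = np.zeros((a * b, rank), dtype=np.int32)\n",
        "  v = np.zeros((b * c, rank), dtype=np.int32)\n",
        "  w = np.zeros((c * a, rank), dtype=np.int32)\n",
        "  r = 0\n",
        "  for i in range(a):\n",
        "    for j in range(b):\n",
        "      for k in range(c):\n",
        "        # Term r multiplies A[i, j] by B[j, k] and adds it to entry (k, i)\n",
        "        # of the transposed product.\n",
        "        u[i * b + j, r] = 1\n",
        "        v[j * c + k, r] = 1\n",
        "        w[k * a + i, r] = 1\n",
        "        r += 1\n",
        "  return u, v, w\n",
        "\n",
        "\n",
        "def multiply_with_factorization(mat_a, mat_b, u, v, w):\n",
        "  \"\"\"Computes mat_a @ mat_b with one scalar product per rank-one term.\"\"\"\n",
        "  a, c = mat_a.shape[0], mat_b.shape[1]\n",
        "  # One linear combination of the entries of A and of B per term.\n",
        "  left = u.T @ mat_a.reshape(-1)\n",
        "  right = v.T @ mat_b.reshape(-1)\n",
        "  # w maps the products to the entries of (A @ B)^T in row-major order.\n",
        "  return (w @ (left * right)).reshape(c, a).T\n",
        "\n",
        "\n",
        "# Toy check against NumPy on small random integer matrices.\n",
        "rng = np.random.default_rng(0)\n",
        "toy_mat_a = rng.integers(-3, 4, size=(2, 3))\n",
        "toy_mat_b = rng.integers(-3, 4, size=(3, 4))\n",
        "u_std, v_std, w_std = standard_factorization(2, 3, 4)\n",
        "product = multiply_with_factorization(toy_mat_a, toy_mat_b, u_std, v_std, w_std)\n",
        "assert np.array_equal(product, toy_mat_a @ toy_mat_b)"
      ]
    },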
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "75ozbYv50Aeg"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Explore factorizations",
      "private_outputs": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
